from collections import defaultdict
import torch
import torch.nn as nn
from utils.utils import Task_type_handler

class GATEModel(nn.Module):
    """GT baseline model: per-task branches run on clean and perturbed features."""

    def __init__(self, backbone_front: nn.Module, backbone_back: nn.Module, bottlenecks: nn.Module,
                 transform: nn.Module, inv_transform: nn.Module, heads: nn.Module, d_num, perb_ratio):
        """Total model architecture for geometric transfer.

        The per-task sub-modules (``backbone_back``, ``bottlenecks``, ``transform``,
        ``inv_transform``, ``heads``) are indexed by task name, so they are expected to be
        mapping-like module containers (e.g. ``nn.ModuleDict``) sharing the same task keys.

        Args:
            backbone_front (nn.Module): Shared front model applied to the raw input; its
                output is indexed (``[0]`` is the part that gets perturbed).
            backbone_back (nn.Module): Model back for each task; its keys define the tasks.
            bottlenecks (nn.Module): Encoder for each task; its output is indexed with
                ``[0]`` (stored as ``'encoder'``) and ``[1]`` (stored as ``'decoder'``).
            transform (nn.Module): Per-task map to the flat space.
            inv_transform (nn.Module): Per-task map back to the original space.
            heads (nn.Module): Per-task downstream head.
            d_num (int): Number of perturbed views per task; ``forward`` requires >= 1.
            perb_ratio: Perturbation scale, either one value shared by all tasks or a
                dict mapping task name to its scale.
        """
        super().__init__()
        self.backbone_front = backbone_front
        self.backbone_back = backbone_back
        self.bottlenecks = bottlenecks
        self.transform = transform
        self.inv_transform = inv_transform
        self.heads = heads
        self.d_num = d_num
        self.perb_ratio = perb_ratio
        # One coefficient in [0, 1) per perturbed view, drawn once at construction.
        # NOTE: plain tensor attribute (not a registered buffer), so it is not saved in
        # state_dict; it is cast to the input's device/dtype on every use instead.
        self.rand = torch.rand(self.d_num)

        def normal_perturbation(x, task, idx):
            """Return ``x + rand[idx - 1] * ratio * std(x, -1) + 1e-4``.

            ``idx`` is 1-based (1..d_num). ``ratio`` is ``perb_ratio[task]`` when
            ``perb_ratio`` is a dict (or dict subclass), otherwise ``perb_ratio`` itself.
            """
            ratio = self.perb_ratio[task] if isinstance(self.perb_ratio, dict) else self.perb_ratio
            return x + (self.rand[idx - 1].to(x) * ratio * x.std(-1).unsqueeze(-1)) + torch.tensor([0.0001]).to(x)

        self.perturbation = normal_perturbation

    def forward(self, input, task_type: str):
        """Run every task branch on the clean view and ``d_num`` perturbed views.

        Args:
            input: Batch passed to ``backbone_front``; must expose a ``y`` attribute.
            task_type (str): Name of the main task. The first other task name in
                ``backbone_back`` is used as the sub task (at least two tasks assumed).

        Returns:
            dict: One entry per task name with keys ``'down'``, ``'encoder'``,
            ``'decoder'``, ``'y_de'``, ``'flat'``, ``'ori'``, ``'map'``, ``'map_down'`` and
            ``'ds'``. List-valued entries hold one item per view: index 0 is the clean view,
            indices 1..d_num are perturbed views. ``'ds'`` is filled for every task;
            ``'down'``, ``'map'``, ``'map_down'`` and ``'y'`` only for ``task_type``.

        Raises:
            ValueError: If ``d_num`` is less than 1 (no perturbed views to compare).
        """
        # Fail fast: ``ds`` below needs at least one perturbed view to stack.
        if self.d_num <= 0:
            raise ValueError(f'0 perturbations: d_num must be >= 1, got {self.d_num}')

        pre_forward_outs = defaultdict(dict)
        outs = self.backbone_front(input)

        for task in self.backbone_back.keys():
            backbone_outs, bottleneck_outs, transfer_outs, main_inv_outs = [], [], [], []
            for j in range(self.d_num + 1):
                if j == 0:
                    # Clean view: feed the front output unchanged.
                    branch_in = outs
                else:
                    # Perturb only the first front output; the remaining items pass through.
                    perturbed = self.perturbation(outs[0], task, j)
                    branch_in = tuple([perturbed] + list(outs)[1:])
                backbone_outs.append(self.backbone_back[task](branch_in))
                bottleneck_outs.append(self.bottlenecks[task](backbone_outs[-1]))
                transfer_outs.append(self.transform[task](bottleneck_outs[-1][0]))
                main_inv_outs.append(self.inv_transform[task](transfer_outs[-1]))

            pre_forward_outs[task]['backbone_out'] = backbone_outs
            pre_forward_outs[task]['bottleneck_out'] = bottleneck_outs
            pre_forward_outs[task]['transfer_out'] = transfer_outs
            pre_forward_outs[task]['main_inv_out'] = main_inv_outs

        pre_forward_outs["y"] = input.y

        # First task name other than the main task (IndexError if there is none).
        sub_task = [a for a in self.backbone_back.keys() if a != task_type][0]

        # Only the last perturbed view (index d_num) of the sub task's flat output is
        # mapped back through the main task's inverse transform.
        # NOTE(review): confirm whether all d_num + 1 views were meant to be concatenated.
        sub_trans_outs = [pre_forward_outs[sub_task]['transfer_out'][self.d_num]]
        pre_forward_outs[task_type]['sub_inv_out'] = self.inv_transform[task_type](torch.cat(sub_trans_outs))

        pre_result_dict = {}
        for task in self.backbone_back.keys():
            branch = pre_forward_outs[task]
            pre_result_dict[task] = {
                'down': [],
                'encoder': [b_out[0] for b_out in branch['bottleneck_out']],
                'decoder': [b_out[1] for b_out in branch['bottleneck_out']],
                'y_de': list(branch['backbone_out']),
                'flat': list(branch['transfer_out']),
                'ori': list(branch['main_inv_out']),
                'map': [],
                'map_down': [],
                'ds': [],
            }

        pre_result_dict[task_type]['map'] = pre_forward_outs[task_type]['sub_inv_out']
        pre_result_dict[task_type]['map_down'] = self.heads[task_type](pre_forward_outs[task_type]['sub_inv_out'])

        for task in self.backbone_back.keys():
            flat = pre_result_dict[task]['flat']
            # Mean (over the last dim) squared difference between the clean flat output
            # and each perturbed one; leading dim has size d_num.
            pre_result_dict[task]['ds'] = torch.mean((flat[0] - torch.stack(flat[1:])) ** 2, dim=-1)

        pre_result_dict[task_type]['down'] = self.heads[task_type](pre_result_dict[task_type]['encoder'][0])
        pre_result_dict[task_type]['y'] = pre_forward_outs['y']
        return pre_result_dict